{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pbVXrmEsLXDt"
   },
   "source": [
    "# ML-Agents run with Stable Baselines 3\n",
    "<img src=\"https://github.com/Unity-Technologies/ml-agents/blob/release/4.0.0/com.unity.ml-agents/Documentation~/images/image-banner.png?raw=true\" align=\"middle\" width=\"435\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WNKTwHU3d2-l"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#@title Install Rendering Dependencies { display-mode: \"form\" }\n",
    "#@markdown (You only need to run this code when using Colab's hosted runtime)\n",
    "\n",
    "import os\n",
    "from IPython.display import HTML, display\n",
    "\n",
    "def progress(value, max=100):\n",
    "    return HTML(\"\"\"\n",
    "        <progress\n",
    "            value='{value}'\n",
    "            max='{max}',\n",
    "            style='width: 100%'\n",
    "        >\n",
    "            {value}\n",
    "        </progress>\n",
    "    \"\"\".format(value=value, max=max))\n",
    "\n",
    "pro_bar = display(progress(0, 100), display_id=True)\n",
    "\n",
    "try:\n",
    "  import google.colab\n",
    "  INSTALL_XVFB = True\n",
    "except ImportError:\n",
    "  INSTALL_XVFB = 'COLAB_ALWAYS_INSTALL_XVFB' in os.environ\n",
    "\n",
    "if INSTALL_XVFB:\n",
    "  with open('frame-buffer', 'w') as writefile:\n",
    "    writefile.write(\"\"\"#taken from https://gist.github.com/jterrace/2911875\n",
    "XVFB=/usr/bin/Xvfb\n",
    "XVFBARGS=\":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n",
    "PIDFILE=./frame-buffer.pid\n",
    "case \"$1\" in\n",
    "  start)\n",
    "    echo -n \"Starting virtual X frame buffer: Xvfb\"\n",
    "    /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n",
    "    echo \".\"\n",
    "    ;;\n",
    "  stop)\n",
    "    echo -n \"Stopping virtual X frame buffer: Xvfb\"\n",
    "    /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n",
    "    rm $PIDFILE\n",
    "    echo \".\"\n",
    "    ;;\n",
    "  restart)\n",
    "    $0 stop\n",
    "    $0 start\n",
    "    ;;\n",
    "  *)\n",
    "        echo \"Usage: /etc/init.d/xvfb {start|stop|restart}\"\n",
    "        exit 1\n",
    "esac\n",
    "exit 0\n",
    "    \"\"\")\n",
    "  !sudo apt-get update\n",
    "  pro_bar.update(progress(10, 100))\n",
    "  !sudo DEBIAN_FRONTEND=noninteractive apt install -y daemon wget gdebi-core build-essential libfontenc1 libfreetype6 xorg-dev xorg\n",
    "  pro_bar.update(progress(20, 100))\n",
    "  !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1\n",
    "  pro_bar.update(progress(30, 100))\n",
    "  !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb 2>&1\n",
    "  pro_bar.update(progress(40, 100))\n",
    "  !sudo dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1\n",
    "  pro_bar.update(progress(50, 100))\n",
    "  !sudo dpkg -i xvfb.deb 2>&1\n",
    "  pro_bar.update(progress(70, 100))\n",
    "  !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n",
    "  pro_bar.update(progress(80, 100))\n",
    "  !rm xvfb.deb\n",
    "  pro_bar.update(progress(90, 100))\n",
    "  !bash frame-buffer start\n",
    "  os.environ[\"DISPLAY\"] = \":1\"\n",
    "pro_bar.update(progress(100, 100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pzj7wgapAcDs"
   },
   "source": [
    "### Installing ml-agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N8yfQqkbebQ5",
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "try:\n",
    "  import mlagents\n",
    "  print(\"ml-agents already installed\")\n",
    "except ImportError:\n",
    "  !python -m pip install -q mlagents==1.1.0\n",
    "  print(\"Installed ml-agents\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_u74YhSmW6gD"
   },
   "source": [
    "## Run the Environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P-r_cB2rqp5x"
   },
   "source": [
    "### Import dependencies and set some high level parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YSf-WhxbqtLw"
   },
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "from typing import Callable, Any\n",
    "\n",
    "import gym\n",
    "from gym import Env\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv\n",
    "from supersuit import observation_lambda_v0\n",
    "\n",
    "\n",
    "from mlagents_envs.environment import UnityEnvironment\n",
    "from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper\n",
    "from mlagents_envs.registry import UnityEnvRegistry, default_registry\n",
    "from mlagents_envs.side_channel.engine_configuration_channel import (\n",
    "    EngineConfig,\n",
    "    EngineConfigurationChannel,\n",
    ")\n",
    "\n",
    "NUM_ENVS = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Environment  and Engine Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default values from CLI (See cli_utils.py)\n",
    "DEFAULT_ENGINE_CONFIG = EngineConfig(\n",
    "    width=84,\n",
    "    height=84,\n",
    "    quality_level=4,\n",
    "    time_scale=20,\n",
    "    target_frame_rate=-1,\n",
    "    capture_frame_rate=60,\n",
    ")\n",
    "\n",
    "# Some config subset of an actual config.yaml file for MLA.\n",
    "@dataclass\n",
    "class LimitedConfig:\n",
    "    # The local path to a Unity executable or the name of an entry in the registry.\n",
    "    env_path_or_name: str\n",
    "    base_port: int\n",
    "    base_seed: int = 0\n",
    "    num_env: int = 1\n",
    "    engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG\n",
    "    visual_obs: bool = False\n",
    "    # TODO: Decide if we should just tell users to always use MultiInputPolicy so we can simplify the user workflow.\n",
    "    # WARNING: Make sure to use MultiInputPolicy if you turn this on.\n",
    "    allow_multiple_obs: bool = False\n",
    "    env_registry: UnityEnvRegistry = default_registry"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Unity Environment SB3 Factory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _unity_env_from_path_or_registry(\n",
    "    env: str, registry: UnityEnvRegistry, **kwargs: Any\n",
    ") -> UnityEnvironment:\n",
    "    if Path(env).expanduser().absolute().exists():\n",
    "        return UnityEnvironment(file_name=env, **kwargs)\n",
    "    elif env in registry:\n",
    "        return registry.get(env).make(**kwargs)\n",
    "    else:\n",
    "        raise ValueError(f\"Environment '{env}' wasn't a local path or registry entry\")\n",
    "        \n",
    "def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> VecEnv:\n",
    "    def handle_obs(obs, space):\n",
    "        if isinstance(space, gym.spaces.Tuple):\n",
    "            if len(space) == 1:\n",
    "                return obs[0]\n",
    "            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).\n",
    "            return {str(i): v for i, v in enumerate(obs)}\n",
    "        return obs\n",
    "\n",
    "    def handle_obs_space(space):\n",
    "        if isinstance(space, gym.spaces.Tuple):\n",
    "            if len(space) == 1:\n",
    "                return space[0]\n",
    "            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).\n",
    "            return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})\n",
    "        return space\n",
    "\n",
    "    def create_env(env: str, worker_id: int) -> Callable[[], Env]:\n",
    "        def _f() -> Env:\n",
    "            engine_configuration_channel = EngineConfigurationChannel()\n",
    "            engine_configuration_channel.set_configuration(config.engine_config)\n",
    "            kwargs[\"side_channels\"] = kwargs.get(\"side_channels\", []) + [\n",
    "                engine_configuration_channel\n",
    "            ]\n",
    "            unity_env = _unity_env_from_path_or_registry(\n",
    "                env=env,\n",
    "                registry=config.env_registry,\n",
    "                worker_id=worker_id,\n",
    "                base_port=config.base_port,\n",
    "                seed=config.base_seed + worker_id,\n",
    "                **kwargs,\n",
    "            )\n",
    "            new_env = UnityToGymWrapper(\n",
    "                unity_env=unity_env,\n",
    "                uint8_visual=config.visual_obs,\n",
    "                allow_multiple_obs=config.allow_multiple_obs,\n",
    "            )\n",
    "            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)\n",
    "            return new_env\n",
    "\n",
    "        return _f\n",
    "\n",
    "    env_facts = [\n",
    "        create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env)\n",
    "    ]\n",
    "    return SubprocVecEnv(env_facts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start Environment from the registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# -----------------\n",
    "# This code is used to close an env that might not have been closed before\n",
    "try:\n",
    "  env.close()\n",
    "except:\n",
    "  pass\n",
    "# -----------------\n",
    "\n",
    "env = make_mla_sb3_env(\n",
    "    config=LimitedConfig(\n",
    "        env_path_or_name='Basic',  # Can use any name from a registry or a path to your own unity build.\n",
    "        base_port=6006,\n",
    "        base_seed=42,\n",
    "        num_env=NUM_ENVS,\n",
    "        allow_multiple_obs=True,\n",
    "    ),\n",
    "    no_graphics=True,  # Set to false if you are running locally and want to watch the environments move around as they train.\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# 250K should train to a reward ~= 0.90 for the \"Basic\" environment.\n",
    "# We set the value lower here to demonstrate just a small amount of trianing.\n",
    "BATCH_SIZE = 32\n",
    "BUFFER_SIZE = 256\n",
    "UPDATES = 50\n",
    "TOTAL_TAINING_STEPS_GOAL = BUFFER_SIZE * UPDATES\n",
    "BETA = 0.0005\n",
    "N_EPOCHS = 3 \n",
    "STEPS_PER_UPDATE = BUFFER_SIZE / NUM_ENVS\n",
    "\n",
    "# Helps gather stats for our eval() calls later so we can see reward stats.\n",
    "env = VecMonitor(env)\n",
    "\n",
    "#Policy and Value function with 2 layers of 128 units each and no shared layers.\n",
    "policy_kwargs = {\"net_arch\" : [{\"pi\": [32,32], \"vf\": [32,32]}]}\n",
    "\n",
    "model = PPO(\n",
    "    \"MlpPolicy\",\n",
    "    env,\n",
    "    verbose=1,\n",
    "    learning_rate=lambda progress: 0.0003 * (1.0 - progress),\n",
    "    clip_range=lambda progress: 0.2 * (1.0 - progress),\n",
    "    clip_range_vf=lambda progress: 0.2 * (1.0 - progress),\n",
    "    # Uncomment this if you want to log tensorboard results when running this notebook locally.\n",
    "    # tensorboard_log=\"results\",\n",
    "    policy_kwargs=policy_kwargs,\n",
    "    n_steps=int(STEPS_PER_UPDATE),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    n_epochs=N_EPOCHS,\n",
    "    ent_coef=BETA,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# 0.93 is considered solved for the Basic environment\n",
    "for i in range(UPDATES):\n",
    "    print(f\"Training round {i + 1}/{UPDATES}\")\n",
    "    # NOTE: rest_num_timesteps should only happen the first time so that tensorboard logs are consistent.\n",
    "    model.learn(total_timesteps=BUFFER_SIZE, reset_num_timesteps=(i == 0))\n",
    "    model.policy.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h1lIx3_l24OP"
   },
   "source": [
    "### Close the environment\n",
    "Frees up the ports being used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vdWG6_SqtNtv",
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "env.close()\n",
    "print(\"Closed environment\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Colab-UnityEnvironment-1-Run.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
